from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Union

from hydra.utils import call
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST

from src.data.data_utils import split_subsets_train_val, split_dataset_train_val, add_attrs


class MNISTDataModule(LightningDataModule):
    """Standard MNIST, train, val, test splits and transforms, partitioned per client.

    The training set is divided into per-client subsets by passing
    ``split_function`` to ``hydra.utils.call``. The train/val loaders serve the
    subsets of client ``current_client_idx``, which :meth:`next_client` advances.

    >>> MNISTDataModule(split_function=None)  # doctest: +ELLIPSIS
    <...mnist_datamodule.MNISTDataModule object at ...>
    """

    name = "mnist"

    def __init__(
        self,
        split_function,
        data_dir: Union[str, Path] = Path("/tmp"),
        val_split: float = 0.1,
        num_workers: int = 16,
        normalize: bool = False,
        seed: int = 42,
        batch_size: int = 32,
        num_clients: int = 3,
        fair_val: bool = False,
        *args,
        **kwargs,
    ):
        """
        Args:
            split_function: object passed to ``hydra.utils.call(..., dataset=...)``;
                its result is indexed by client, so presumably one subset per client.
            data_dir: where to save/load the data
            val_split: presumably the fraction of training images held out for
                validation (passed to the split helpers in ``setup``).
            num_workers: how many workers to use for loading data
            normalize: if true, also normalize images with mean 0.5 / std 0.5
            seed: seed forwarded to the train/val split helpers.
            batch_size: desired batch size.
            num_clients: number of clients; :meth:`next_client` cannot go past it.
            fair_val: if true, one validation subset is carved from the full
                training set and every client gets its own copy of it; otherwise
                each client's subset is split into its own train/val parts.
            *args, **kwargs: forwarded to ``LightningDataModule.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.data_dir = data_dir
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        self.num_clients = num_clients
        self.fair_val = fair_val
        self.split_function = split_function
        # Per-client subsets, filled in by ``setup``.
        self.datasets_train: Optional[List[Subset]] = None
        self.datasets_val: Optional[List[Subset]] = None
        # Whole-split datasets, filled in by ``setup`` (train) / ``transfer_setup``.
        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None
        self.test_dataset: Optional[Dataset] = None
        # Index of the client whose subsets the train/val loaders serve.
        self.current_client_idx = 0

    @property
    def num_classes(self):
        """Number of MNIST digit classes."""
        return 10

    def prepare_data(self):
        """Saves MNIST files to `data_dir`"""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Split the training set into per-client train and validation subsets.

        Runs for ``stage`` in ``(None, "fit", "validate")``. ``None`` is
        Lightning's "all stages", and ``"validate"`` needs ``datasets_val`` too.
        Other stages (e.g. ``"test"``) need nothing here because
        ``test_dataloader`` loads the test split itself.
        """
        if stage not in (None, "fit", "validate"):
            return

        self.train_dataset = MNIST(
            self.data_dir, train=True,
            download=False,
            transform=self.default_transforms,
        )
        if self.fair_val:
            # A single hold-out set is carved from the full training set first ...
            train_subset, val_subset = split_dataset_train_val(
                train_dataset=self.train_dataset,
                val_split=self.val_split,
                seed=self.seed,
            )
            # ... then only the remaining training part is divided among clients.
            self.datasets_train = call(self.split_function, dataset=train_subset)
            # Each client gets its own copy of the same validation subset.
            # NOTE(review): deepcopy of a Subset also copies the wrapped MNIST
            # dataset once per client; a shallow copy may suffice if add_attrs
            # does not mutate it — confirm before changing.
            self.datasets_val = [deepcopy(val_subset) for _ in range(self.num_clients)]
            add_attrs(self.datasets_train, self.datasets_val)  # important
        else:
            # Divide the full training set among clients, then split each
            # client's subset: ([train1, ..., trainN], [val1, ..., valN]).
            subsets = call(self.split_function, dataset=self.train_dataset)
            self.datasets_train, self.datasets_val = split_subsets_train_val(
                subsets, self.val_split, self.seed
            )

    def transfer_setup(self):
        """Load whole (not per-client) MNIST train, val and test datasets."""
        self.train_dataset = MNIST(
            self.data_dir, train=True, download=False, transform=self.default_transforms
        )
        # NOTE(review): the validation set is built from the *training* split
        # (train=True), i.e. the same images as ``train_dataset`` — confirm intended.
        self.val_dataset = MNIST(
            self.data_dir, train=True, download=False, transform=self.default_transforms
        )
        self.test_dataset = MNIST(
            self.data_dir, train=False, download=False, transform=self.default_transforms
        )

    def next_client(self):
        """Switch the train/val loaders to the next client's subsets.

        Raises:
            AssertionError: if the current client is already the last one; the
                client index is left unchanged in that case.
        """
        next_idx = self.current_client_idx + 1
        # Explicit raise rather than ``assert`` so the check survives ``python -O``;
        # AssertionError is kept for compatibility with existing callers.
        if next_idx >= self.num_clients:
            raise AssertionError("Client number shouldn't excced seleced number of clients")
        self.current_client_idx = next_idx

    def _make_loader(self, dataset: Dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Shuffled loader over the current client's training subset (partial last batch dropped)."""
        # check this: https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html#multiple-training-dataloaders
        return self._make_loader(
            self.datasets_train[self.current_client_idx], shuffle=True, drop_last=True
        )

    def val_dataloader(self):
        """Unshuffled loader over the current client's validation subset."""
        return self._make_loader(
            self.datasets_val[self.current_client_idx], shuffle=False, drop_last=False
        )

    def test_dataloader(self):
        """MNIST test set uses the test split."""
        dataset = MNIST(
            self.data_dir, train=False, download=False, transform=self.default_transforms
        )
        return self._make_loader(dataset, shuffle=False, drop_last=False)

    @property
    def default_transforms(self):
        """Resize to 32x32 and convert to a tensor; normalize (mean/std 0.5) if enabled."""
        steps = [transform_lib.Resize((32, 32)), transform_lib.ToTensor()]
        if self.normalize:
            steps.append(transform_lib.Normalize(mean=(0.5,), std=(0.5,)))
        return transform_lib.Compose(steps)
